{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OYlaRwNu7ojq"
   },
   "source": [
    "# **Homework 2: Phoneme Classification**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "A7DRC5V7_8A5"
   },
   "source": [
    "Objectives:\n",
    "* Solve a classification problem with deep neural networks (DNNs).\n",
    "* Understand recurrent neural networks (RNNs).\n",
    "\n",
    "If you have any questions, please contact the TAs during TA hours, on NTU COOL, or by email at mlta-2023-spring@googlegroups.com."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KVUGfWTo7_Oj"
   },
   "source": [
    "# Download Data\n",
    "Download the data from Google Drive, then unzip it.\n",
    "\n",
    "You should have\n",
    "- `libriphone/train_split.txt`: training metadata\n",
    "- `libriphone/train_labels.txt`: training labels\n",
    "- `libriphone/test_split.txt`: testing metadata\n",
    "- `libriphone/feat/train/*.pt`: training features\n",
    "- `libriphone/feat/test/*.pt`: testing features\n",
    "\n",
    "after running the following block.\n",
    "\n",
    "> **Note: if the Google Drive link is dead, you can download the data directly from [Kaggle](https://www.kaggle.com/c/ml2023spring-hw2/data) and upload it to the workspace.**\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pADUiYODJE1O"
   },
   "source": [
    "# Some Utility Functions\n",
    "**Fixes random number generator seeds for reproducibility.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BsZKgBZQJjaE"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "\n",
    "def same_seeds(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_L_4anls8Drv"
   },
   "source": [
    "**Helper functions to pre-process the training data from raw MFCC features of each utterance.**\n",
    "\n",
    "A phoneme may span several frames and depends on past and future frames. \\\n",
    "Hence, we concatenate neighboring frames for training to achieve higher accuracy. The **concat_feat** function concatenates the k past and k future frames around each frame (2k+1 = n frames in total), and we predict the label of the center frame.\n",
    "\n",
    "Feel free to modify the data preprocessing functions, but **do not drop any frame** (if you modify the functions, remember to check that the number of frames is the same as mentioned in the slides)."
   ]
  },
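  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The context-window idea above can be sketched independently of the helper functions below. This is only an illustrative sketch: `context_window` is a hypothetical name that is not part of the homework code, and it assumes that positions before the first or after the last frame reuse the nearest edge frame.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def context_window(x, n):\n",
    "    # x: (seq_len, feature_dim); n: odd window size (n = 2k + 1).\n",
    "    assert n % 2 == 1\n",
    "    k = n // 2\n",
    "    seq_len = x.size(0)\n",
    "    # For every center frame, list the indices of its n neighbors.\n",
    "    offsets = torch.arange(-k, k + 1)                    # (n,)\n",
    "    idx = torch.arange(seq_len).unsqueeze(1) + offsets   # (seq_len, n)\n",
    "    # Assumption: out-of-range neighbors reuse the first or last frame.\n",
    "    idx = idx.clamp(0, seq_len - 1)\n",
    "    # Gather the neighbors and flatten them into one feature vector per frame.\n",
    "    return x[idx].reshape(seq_len, -1)                   # (seq_len, n * feature_dim)\n",
    "\n",
    "frames = torch.arange(4, dtype=torch.float32).unsqueeze(1)  # 4 frames, 1 feature each\n",
    "windows = context_window(frames, 3)\n",
    "print(windows)\n",
    "```\n",
    "\n",
    "Each row holds one center frame together with its k neighbors on either side, so an utterance with `seq_len` frames still yields `seq_len` rows and no frame is dropped."
   ]
  },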
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IJjLT8em-y9G"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "def load_feat(path):\n",
    "    feat = torch.load(path)\n",
    "    return feat\n",
    "\n",
    "def shift(x, n):\n",
    "    if n < 0:\n",
    "        left = x[0].repeat(-n, 1)\n",
    "        right = x[:n]\n",
    "    elif n > 0:\n",
    "        right = x[-1].repeat(n, 1)\n",
    "        left = x[n:]\n",
    "    else:\n",
    "        return x\n",
    "\n",
    "    return torch.cat((left, right), dim=0)\n",
    "\n",
    "def concat_feat(x, concat_n):\n",
    "    assert concat_n % 2 == 1 # n must be odd\n",
    "    if concat_n < 2:\n",
    "        return x\n",
    "    seq_len, feature_dim = x.size(0), x.size(1)\n",
    "    x = x.repeat(1, concat_n)\n",
    "    x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim\n",
    "    mid = (concat_n // 2)\n",
    "    for r_idx in range(1, mid+1):\n",
    "        x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)\n",
    "        x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)\n",
    "\n",
    "    return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)\n",
    "\n",
    "def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, random_seed=1213):\n",
    "    class_num = 41 # NOTE: pre-computed, should not need change\n",
    "\n",
    "    if split == 'train' or split == 'val':\n",
    "        mode = 'train'\n",
    "    elif split == 'test':\n",
    "        mode = 'test'\n",
    "    else:\n",
    "        raise ValueError('Invalid \\'split\\' argument for dataset: PhoneDataset!')\n",
    "\n",
    "    label_dict = {}\n",
    "    if mode == 'train':\n",
    "        # map each utterance id to its per-frame phoneme labels\n",
    "        with open(os.path.join(phone_path, f'{mode}_labels.txt')) as f:\n",
    "            for line in f:\n",
    "                line = line.strip('\\n').split(' ')\n",
    "                label_dict[line[0]] = [int(p) for p in line[1:]]\n",
    "\n",
    "        # split training and validation data\n",
    "        with open(os.path.join(phone_path, 'train_split.txt')) as f:\n",
    "            usage_list = f.readlines()\n",
    "        random.seed(random_seed)\n",
    "        random.shuffle(usage_list)\n",
    "        train_len = int(len(usage_list) * train_ratio)\n",
    "        usage_list = usage_list[:train_len] if split == 'train' else usage_list[train_len:]\n",
    "\n",
    "    elif mode == 'test':\n",
    "        with open(os.path.join(phone_path, 'test_split.txt')) as f:\n",
    "            usage_list = f.readlines()\n",
    "\n",
    "    usage_list = [line.strip('\\n') for line in usage_list]\n",
    "    print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))\n",
    "\n",
    "    max_len = 3000000\n",
    "    X = torch.empty(max_len, 39 * concat_nframes)\n",
    "    if mode == 'train':\n",
    "        y = torch.empty(max_len, dtype=torch.long)\n",
    "\n",
    "    idx = 0\n",
    "    for fname in tqdm(usage_list):\n",
    "        feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))\n",
    "        cur_len = len(feat)\n",
    "        feat = concat_feat(feat, concat_nframes)\n",
    "        if mode == 'train':\n",
    "            label = torch.LongTensor(label_dict[fname])\n",
    "\n",
    "        X[idx: idx + cur_len, :] = feat\n",
    "        if mode == 'train':\n",
    "            y[idx: idx + cur_len] = label\n",
    "\n",
    "        idx += cur_len\n",
    "\n",
    "    # trim the preallocated buffers to the number of frames actually loaded\n",
    "    X = X[:idx, :]\n",
    "    if mode == 'train':\n",
    "        y = y[:idx]\n",
    "\n",
    "    print(f'[INFO] {split} set')\n",
    "    print(X.shape)\n",
    "    if mode == 'train':\n",
    "        print(y.shape)\n",
    "        return X, y\n",
    "    else:\n",
    "        return X\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "us5XW_x6udZQ"
   },
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fjf5EcmJtf4e"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "class LibriDataset(Dataset):\n",
    "    def __init__(self, X, y=None):\n",
    "        self.data = X\n",
    "        if y is not None:\n",
    "            self.label = torch.LongTensor(y)\n",
    "        else:\n",
    "            self.label = None\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.label is not None:\n",
    "            return self.data[idx], self.label[idx]\n",
    "        else:\n",
    "            return self.data[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IRqKNvNZwe3V"
   },
   "source": [
    "# Model\n",
    "Feel free to modify the structure of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Bg-GRd7ywdrL"
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super(BasicBlock, self).__init__()\n",
    "\n",
    "        # TODO: apply batch normalization and dropout for strong baseline.\n",
    "        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html (batch normalization)\n",
    "        #       https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html (dropout)\n",
    "        self.block = nn.Sequential(\n",
    "            nn.Linear(input_dim, output_dim),\n",
    "            nn.BatchNorm1d(output_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.2)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.block(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Classifier(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):\n",
    "        super(Classifier, self).__init__()\n",
    "\n",
    "        self.fc = nn.Sequential(\n",
    "            BasicBlock(input_dim, hidden_dim),\n",
    "            *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],\n",
    "            nn.Linear(hidden_dim, output_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TlIq8JeqvvHC"
   },
   "source": [
    "# Hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iIHn79Iav1ri"
   },
   "outputs": [],
   "source": [
    "# data parameters\n",
    "# TODO: change the value of \"concat_nframes\" for medium baseline\n",
    "concat_nframes = 25   # the number of frames to concatenate, n must be odd (total 2k+1 = n frames)\n",
    "train_ratio = 0.75   # the ratio of data used for training, the rest will be used for validation\n",
    "\n",
    "# training parameters\n",
    "seed = 1213          # random seed\n",
    "batch_size = 512        # batch size\n",
    "num_epoch = 10         # the number of training epochs\n",
    "learning_rate = 1e-4      # learning rate\n",
    "model_path = './model.ckpt'  # the path where the checkpoint will be saved\n",
    "\n",
    "# model parameters\n",
    "# TODO: change the value of \"hidden_layers\" or \"hidden_dim\" for medium baseline\n",
    "input_dim = 39 * concat_nframes  # the input dim of the model, you should not change the value\n",
    "hidden_layers = 24          # the number of hidden layers\n",
    "hidden_dim = 1024           # the hidden dim\n",
    "print(concat_nframes, hidden_layers, hidden_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IIUFRgG5yoDn"
   },
   "source": [
    "# Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "c1zI3v5jyrDn"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import gc\n",
    "\n",
    "same_seeds(seed)\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'DEVICE: {device}')\n",
    "\n",
    "# preprocess data\n",
    "train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio, random_seed=seed)\n",
    "val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio, random_seed=seed)\n",
    "\n",
    "# get dataset\n",
    "train_set = LibriDataset(train_X, train_y)\n",
    "val_set = LibriDataset(val_X, val_y)\n",
    "\n",
    "# remove raw feature to save memory\n",
    "del train_X, train_y, val_X, val_y\n",
    "gc.collect()\n",
    "\n",
    "# get dataloader\n",
    "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pwWH1KIqzxEr"
   },
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CdMWsBs7zzNs"
   },
   "outputs": [],
   "source": [
    "# create model, define a loss function, and optimizer\n",
    "model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "best_acc = 0.0\n",
    "for epoch in range(num_epoch):\n",
    "    train_acc = 0.0\n",
    "    train_loss = 0.0\n",
    "    val_acc = 0.0\n",
    "    val_loss = 0.0\n",
    "\n",
    "    # training\n",
    "    model.train() # set the model to training mode\n",
    "    for i, batch in enumerate(tqdm(train_loader)):\n",
    "        features, labels = batch\n",
    "        features = features.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(features)\n",
    "\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
    "        train_acc += (train_pred.detach() == labels.detach()).sum().item()\n",
    "        train_loss += loss.item()\n",
    "\n",
    "    # validation\n",
    "    model.eval() # set the model to evaluation mode\n",
    "    with torch.no_grad():\n",
    "        for i, batch in enumerate(tqdm(val_loader)):\n",
    "            features, labels = batch\n",
    "            features = features.to(device)\n",
    "            labels = labels.to(device)\n",
    "            outputs = model(features)\n",
    "\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "            _, val_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
    "            val_acc += (val_pred.cpu() == labels.cpu()).sum().item()\n",
    "            val_loss += loss.item()\n",
    "\n",
    "    print(f'[{epoch+1:03d}/{num_epoch:03d}] Train Acc: {train_acc/len(train_set):3.5f} Loss: {train_loss/len(train_loader):3.5f} | Val Acc: {val_acc/len(val_set):3.5f} Loss: {val_loss/len(val_loader):3.5f}')\n",
    "\n",
    "    # if the model improves, save a checkpoint at this epoch\n",
    "    if val_acc > best_acc:\n",
    "        best_acc = val_acc\n",
    "        torch.save(model.state_dict(), model_path)\n",
    "        print(f'saving model with acc {best_acc/len(val_set):.5f}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ab33MxosWLmG"
   },
   "outputs": [],
   "source": [
    "del train_set, val_set\n",
    "del train_loader, val_loader\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1Hi7jTn3PX-m"
   },
   "source": [
    "# Testing\n",
    "Create a testing dataset and load the model from the saved checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VOG1Ou0PGrhc"
   },
   "outputs": [],
   "source": [
    "# load data\n",
    "test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)\n",
    "test_set = LibriDataset(test_X, None)\n",
    "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ay0Fu8Ovkdad"
   },
   "outputs": [],
   "source": [
    "# load model\n",
    "model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zp-DV1p4r7Nz"
   },
   "source": [
    "Make predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "84HU5GGjPqR0"
   },
   "outputs": [],
   "source": [
    "pred = np.array([], dtype=np.int32)\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for i, batch in enumerate(tqdm(test_loader)):\n",
    "        features = batch\n",
    "        features = features.to(device)\n",
    "\n",
    "        outputs = model(features)\n",
    "\n",
    "        _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
    "        pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wyZqy40Prz0v"
   },
   "source": [
    "Write the predictions to a CSV file.\n",
    "\n",
    "After this block finishes running, download the file `prediction.csv` from the files section on the left-hand side and submit it to Kaggle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GuljYSPHcZir"
   },
   "outputs": [],
   "source": [
    "with open('prediction.csv', 'w') as f:\n",
    "    f.write('Id,Class\\n')\n",
    "    for i, y in enumerate(pred):\n",
    "        f.write('{},{}\\n'.format(i, y))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "d2l",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
